{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ff393bad",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff63d294",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment on Colab\n",
    "#!pip install dm-haiku\n",
    "#!pip install optax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ed8edd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import haiku as hk\n",
    "from functools import partial\n",
    "import math\n",
    "import numpy as np\n",
    "from numpy.random import random\n",
    "import optax\n",
    "import torch\n",
    "from typing import Any, Sequence"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d866ae3",
   "metadata": {},
   "source": [
    "# Playing with Jax/Haiku/Optax\n",
    "\n",
    "In this notebook, I am learning [JAX](https://jax.readthedocs.io/en/latest/index.html) and the libraries [Haiku](https://dm-haiku.readthedocs.io/en/latest/index.html) and [Optax](https://optax.readthedocs.io/en/latest/index.html) by reimplementing the linear regression example of [dataflowr module 2b](https://dataflowr.github.io/website/modules/2b-automatic-differentiation/),\n",
    "\n",
    "I found Sabrina Mielke’s post very useful: [From PyTorch to JAX: towards neural net frameworks that purify stateful code](https://sjmielke.com/jax-purify.htm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98f4bf80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate random input data\n",
    "x = random((30,2)).astype('float32')\n",
    "# generate labels corresponding to input data x\n",
    "y = np.dot(x, [2., -3.]) + 1.\n",
    "y = np.expand_dims(y, axis=1).astype('float32')\n",
    "w_source = np.array([2., -3.])\n",
    "b_source  = np.array([1.])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bb6d374",
   "metadata": {},
   "source": [
    "# Setup\n",
    "\n",
    "Our model is:\n",
    "$$\n",
    "y_t = 2x^1_t-3x^2_t+1, \\quad t\\in\\{1,\\dots,30\\}\n",
    "$$\n",
    "\n",
    "Our task is given the 'observations' $(x_t,y_t)_{t\\in\\{1,\\dots,30\\}}$ to recover the weights $w^1=2, w^2=-3$ and the bias $b = 1$.\n",
    "\n",
    "In order to do so, we will solve the following optimization problem:\n",
    "$$\n",
    "\\underset{w^1,w^2,b}{\\operatorname{argmin}} \\sum_{t=1}^{30} \\left(w^1x^1_t+w^2x^2_t+b-y_t\\right)^2\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d1e780e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# randomly initialize learnable weights and bias\n",
    "w_init = random(2)\n",
    "b_init = random(1)\n",
    "\n",
    "w = w_init\n",
    "b = b_init\n",
    "print(\"initial values of the parameters:\", w, b )\n",
    "\n",
    "dtype = torch.FloatTensor\n",
    "w_init_t = torch.from_numpy(w_init).type(dtype)\n",
    "b_init_t = torch.from_numpy(b_init).type(dtype)\n",
    "x_t = torch.from_numpy(x).type(dtype)\n",
    "y_t = torch.from_numpy(y).type(dtype)\n",
    "\n",
    "learning_rate = 1e-2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e85d7d45",
   "metadata": {},
   "source": [
    "# Pytorch version\n",
    "\n",
    "Pytorch implementation taken from [Dataflowr module 2b](https://dataflowr.github.io/website/modules/2b-automatic-differentiation/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2f4b903",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torch.nn.Sequential(torch.nn.Linear(2, 1),)\n",
    "\n",
    "for m in model.children():\n",
    "    m.weight.data = w_init_t.clone().unsqueeze(0)\n",
    "    m.bias.data = b_init_t.clone()\n",
    "\n",
    "loss_fn = torch.nn.MSELoss(reduction='sum')\n",
    "\n",
    "model.train()\n",
    "\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "for epoch in range(10):\n",
    "    y_pred = model(x_t)\n",
    "    loss = loss_fn(y_pred, y_t)\n",
    "    print(\"progress:\", \"epoch:\", epoch, \"loss\",loss.item())\n",
    "    # Zero gradients, perform a backward pass, and update the weights.\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "# After training\n",
    "print(\"estimation of the parameters:\")\n",
    "for param in model.parameters():\n",
    "    print(param)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23b14c59",
   "metadata": {},
   "source": [
    "# Jax linear layer with Haiku\n",
    "\n",
    "Playing with a linear layer... you first define your pyhton function and 'jaxify' it into a pure function thanks to the `hk.transfrom`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a30ee2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _linear(x, config):\n",
    "    return hk.Linear(config.size_out)(x)\n",
    "\n",
    "class toy_config:\n",
    "    size_out = 5\n",
    "\n",
    "linear = hk.without_apply_rng(hk.transform(lambda x: _linear(x, config=toy_config)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "011e4442",
   "metadata": {},
   "outputs": [],
   "source": [
    "rng_key = jax.random.PRNGKey(42)\n",
    "x_dummy = jax.random.normal(key=rng_key, shape=(1,7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d3453d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = linear.init(rng=rng_key, x=x_dummy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78f346c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fed4bf74",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_in = jax.random.normal(key=rng_key, shape=(10,3,7))\n",
    "out = linear.apply(x=x_in, params= params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c7b0244",
   "metadata": {},
   "outputs": [],
   "source": [
    "out.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "740503b1",
   "metadata": {},
   "source": [
    "# Haiku: Linear layer with initialization\n",
    "\n",
    "I want to check my results by comparing them to the Pytorch implementation so I need to start with the same initial values for the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feeecad2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# inspired from https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/initializers.py#L51#L63\n",
    "class Init_jnparray(hk.initializers.Initializer):\n",
    "    def __init__(self, w: jnp.ndarray):\n",
    "        self.w = w\n",
    "\n",
    "    def __call__(self, shape: Sequence[int], dtype: Any) -> jnp.ndarray:\n",
    "        if self.w.shape != tuple(shape):\n",
    "            raise ValueError('Error in shape! w:', self.w.shape,' and shape:', shape)\n",
    "        return self.w.astype(dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29c6205a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class config:\n",
    "    size_out = 1\n",
    "    w_source = jnp.array([w_init]).swapaxes(1,0)\n",
    "    b_source = jnp.array(b_init)\n",
    "\n",
    "\n",
    "def _linear(x, config):\n",
    "    return hk.Linear(config.size_out,w_init=Init_jnparray(config.w_source), b_init=Init_jnparray(config.b_source))(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ebe54d1",
   "metadata": {},
   "source": [
    "Note taht noting is random here (no random initialization)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "259207b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_dummy = jax.random.normal(key=rng_key, shape=(1,2))\n",
    "linear = hk.without_apply_rng(hk.transform(lambda x: _linear(x, config=config)))\n",
    "params = linear.init(x=x_dummy,rng=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6de29ce6",
   "metadata": {},
   "outputs": [],
   "source": [
    "params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8875475",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = linear.apply(x=x,params=params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e635937",
   "metadata": {},
   "outputs": [],
   "source": [
    "out.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1a48472",
   "metadata": {},
   "source": [
    "# Computing loss and gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40ac8afc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mse_loss(y_pred, y_t):\n",
    "    return jax.lax.integer_pow(y_pred - y_t,2).sum()\n",
    "\n",
    "mse_loss(out,y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0219a867",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(x_in, y_t, config):\n",
    "    return mse_loss(_linear(x=x_in, config=config),y_t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1077aa13",
   "metadata": {},
   "outputs": [],
   "source": [
    "hk_loss_fn = hk.without_apply_rng(hk.transform(partial(loss_fn, config=config)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25d7ae2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = hk_loss_fn.init(rng=rng_key, x_in=x,y_t=y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5af73cb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "params"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b08d9b80",
   "metadata": {},
   "source": [
    "Redefining the loss as the `apply` method of the transformed haiku loss and then the pytorch `backward` operation is done in jax with [`value_and_grad`](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#evaluate-a-function-and-its-gradient-using-value-and-grad)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24048610",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_fn = hk_loss_fn.apply\n",
    "loss, grads = jax.value_and_grad(loss_fn)(params,x_in=x,y_t=y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3049fd98",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5f82f4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "grads"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c418a83",
   "metadata": {},
   "source": [
    "# Optimizer\n",
    "\n",
    "Now things are rather easy to understand"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c4e02b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optax.sgd(learning_rate=1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e01f7107",
   "metadata": {},
   "outputs": [],
   "source": [
    "opt_state = optimizer.init(params)\n",
    "for epoch in range(10):\n",
    "    loss, grads = jax.value_and_grad(loss_fn)(params,x_in=x,y_t=y)\n",
    "    print(\"progress:\", \"epoch:\", epoch, \"loss\",loss)\n",
    "    updates, opt_state = optimizer.update(grads, opt_state, params)\n",
    "    params = optax.apply_updates(params, updates)\n",
    "    \n",
    "# After training\n",
    "print(\"estimation of the parameters:\")\n",
    "print(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "939e0d37",
   "metadata": {},
   "source": [
    "# Jax/Haiku/Optax full code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9576481a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class config:\n",
    "    size_out = 1\n",
    "    w_source = jnp.array([w_init]).swapaxes(1,0)\n",
    "    b_source = jnp.array(b_init)\n",
    "\n",
    "class Init_jnparray(hk.initializers.Initializer):\n",
    "    def __init__(self, w: jnp.ndarray):\n",
    "        self.w = w\n",
    "\n",
    "    def __call__(self, shape: Sequence[int], dtype: Any) -> jnp.ndarray:\n",
    "        if self.w.shape != tuple(shape):\n",
    "            raise ValueError('Error in shape! w:', self.w.shape,' and shape:', shape)\n",
    "        return self.w.astype(dtype)\n",
    "    \n",
    "def _linear(x, config):\n",
    "    return hk.Linear(config.size_out,w_init=Init_jnparray(config.w_source), b_init=Init_jnparray(config.b_source))(x)\n",
    "\n",
    "def mse_loss(y_pred, y_t):\n",
    "    return jax.lax.integer_pow(y_pred - y_t,2).sum()\n",
    "\n",
    "def loss_fn(x_in, y_t, config):\n",
    "    return mse_loss(_linear(x=x_in, config=config),y_t)\n",
    "\n",
    "hk_loss_fn = hk.without_apply_rng(hk.transform(partial(loss_fn, config=config)))\n",
    "params = hk_loss_fn.init(x_in=x,y_t=y,rng=None)\n",
    "loss_fn = hk_loss_fn.apply\n",
    "\n",
    "optimizer = optax.sgd(learning_rate=1e-2)\n",
    "\n",
    "opt_state = optimizer.init(params)\n",
    "for epoch in range(10):\n",
    "    loss, grads = jax.value_and_grad(loss_fn)(params,x_in=x,y_t=y)\n",
    "    print(\"progress:\", \"epoch:\", epoch, \"loss\",loss)\n",
    "    updates, opt_state = optimizer.update(grads, opt_state, params)\n",
    "    params = optax.apply_updates(params, updates)\n",
    "    \n",
    "# After training\n",
    "print(\"estimation of the parameters:\")\n",
    "print(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "735699aa",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax",
   "language": "python",
   "name": "jax"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
